import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimage
import numpy as np
import os
from glob import glob
import random
from sklearn.metrics import confusion_matrix
import pickle

# Allow more than one copy of the Intel OpenMP runtime to be loaded in this
# process (presumably because several imported libraries each bundle one;
# without this flag the runtime can abort at start-up — confirm on target box).
# Optional, currently disabled:
#   os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})





def _indices_with_label(labels, classes):
    """Return row positions whose ``labels[i, 0]`` is in *classes*, grouped by class.

    Positions are listed class by class in the order of *classes* (ascending
    within each class) as plain ``int``s, so ``random.sample`` over the result
    draws the same items as a per-class list-comprehension would.
    """
    column = np.asarray(labels)[:, 0]
    return [int(i) for c in classes for i in np.flatnonzero(column == c)]


def _center(images):
    """Divide *images* by 255 and subtract their own per-pixel mean (over axis 0)."""
    images = np.asarray(images)
    if len(images) == 0:
        # np.mean over an empty axis warns and yields NaN; nothing to centre.
        return images / 255
    return images / 255 - np.mean(images, axis=0) / 255


def generate_data2(lm, N_tr, N_ts, tr_data, tr_label, ts_data, ts_label):
    """Build a binary train/test split mixing label 0 with labels 2..7.

    The two groups are:

    * "ap" group: samples whose label (column 0) is 0. They keep their original
      label.
    * "animal" group: samples whose label is 2..7. They are relabelled 1.

    Each group is centred separately with its own per-pixel mean (see
    ``_center``).

    Args:
        lm: Fraction (0..1) of the training set drawn from the "ap" group.
        N_tr: Total number of training samples. ``round(N_tr * lm)`` come from
            the "ap" group and the remainder from the "animal" group, so the
            sizes always sum to ``N_tr``.
        N_ts: Test-set budget. ``int(N_ts * 0.5)`` samples are drawn from each
            group.
        tr_data, ts_data: Image arrays indexed along axis 0. Presumably uint8
            pixels in 0..255 (hence the /255); confirm against the .npy files.
        tr_label, ts_label: 2-D label arrays. The class is read from column 0.

    Returns:
        ``(train_data, train_label, test_data, test_label)``. In each split the
        "animal" rows come first, followed by the "ap" rows. Labels are
        ``(n, 1)`` arrays.

    Raises:
        ValueError: from ``random.sample`` if a group has fewer samples than
            requested.
    """
    n_ap_tr = int(round(N_tr * lm))
    # Derive the complement by subtraction: int(N_tr * (1 - lm)) truncates
    # float error (e.g. 5000 * (1 - 0.8) -> 999) and loses a sample.
    n_animal_tr = int(N_tr) - n_ap_tr
    n_animal_ts = int(N_ts * 0.5)
    n_ap_ts = int(N_ts * 0.5)

    animal_classes = range(2, 8)
    index_animal_tr = _indices_with_label(tr_label, animal_classes)
    index_animal_ts = _indices_with_label(ts_label, animal_classes)
    index_ap_tr = _indices_with_label(tr_label, (0,))
    index_ap_ts = _indices_with_label(ts_label, (0,))

    train_index_animal = random.sample(index_animal_tr, n_animal_tr)
    train_index_ap = random.sample(index_ap_tr, n_ap_tr)
    test_index_animal = random.sample(index_animal_ts, n_animal_ts)
    test_index_ap = random.sample(index_ap_ts, n_ap_ts)

    # NOTE(review): each group, and the test set, is centred with its own
    # statistics rather than the training mean. Kept as-is; confirm it is
    # intended.
    train_data = np.concatenate(
        (_center(tr_data[train_index_animal]), _center(tr_data[train_index_ap])),
        axis=0,
    )
    train_label = np.concatenate(
        (np.ones([n_animal_tr, 1], dtype=np.uint8), tr_label[train_index_ap]),
        axis=0,
    )

    test_data = np.concatenate(
        (_center(ts_data[test_index_animal]), _center(ts_data[test_index_ap])),
        axis=0,
    )
    # Use the same indices as the test images so labels match their data.
    test_label = np.concatenate(
        (np.ones([n_animal_ts, 1], dtype=np.uint8), ts_label[test_index_ap]),
        axis=0,
    )
    return train_data, train_label, test_data, test_label


class neural_netowrk(tf.keras.Model):
    """Small CNN classifier with a LeNet-5-like layout.

    The network is conv(6, 5x5) -> max-pool -> conv(16, 5x5) -> max-pool ->
    flatten -> dense(120, sigmoid) -> dense(84, sigmoid) ->
    dense(num_classes, softmax).

    The class name, including its spelling, is kept for compatibility with
    existing callers.
    """

    def __init__(self, seed=1, num_classes=10):
        """Create the layers.

        Args:
            seed: Passed to ``tf.random.set_seed`` so weight initialisation is
                repeatable. This sets TensorFlow's *global* seed and therefore
                also affects random ops created later in the process.
            num_classes: Width of the softmax output layer (default 10).
        """
        super(neural_netowrk, self).__init__()
        tf.random.set_seed(seed)
        # Convolutional feature extractor.
        self.c1 = tf.keras.layers.Conv2D(6, kernel_size=5, activation='relu', name='c1')
        self.m1 = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2, name='m1')
        self.c2 = tf.keras.layers.Conv2D(16, kernel_size=5, activation='relu', name='c2')
        self.m2 = tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2, name='m2')
        # c3 is not used by call(). It is kept so code referencing model.c3
        # keeps working. It is never built, so it holds no weights.
        self.c3 = tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu', name='c3')
        # Created once here instead of inside call(): Keras layers belong in
        # __init__ so they are tracked and not re-instantiated on every
        # forward pass or trace.
        self.flatten = tf.keras.layers.Flatten()
        # Classifier head.
        self.fc1 = tf.keras.layers.Dense(120, activation='sigmoid')
        self.fc2 = tf.keras.layers.Dense(84, activation='sigmoid')
        self.fc3 = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, input):
        '''
        Forward pass.

        :param input: batch of images. Presumably NHWC (e.g. 32x32x3); confirm
            against the caller.
        :return: softmax class probabilities, one row per input sample
        '''
        x = self.c1(input)
        x = self.m1(x)
        x = self.c2(x)
        x = self.m2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)


def main():
    """Run repeated train/evaluate trials on CIFAR-10 subsets.

    Loads ``cifar10/{training,testing}_{data,label}.npy``. Then, ``n_trials``
    times, it:

    1. draws a fresh split with ``generate_data2`` (fraction ``lm1`` of the
       training set taken from label 0);
    2. trains a freshly initialised ``neural_netowrk``;
    3. evaluates it;
    4. prints the accuracy on the test samples labelled 1 ("group 2").
    """
    N_tr = 5000
    N_ts = 500
    lm1 = 0.8
    n_trials = 10
    # Hyper-parameters actually used by compile()/fit() below.
    batch_size = 50
    learning_rate = 1e-3
    epochs = 20

    train_data = np.load('cifar10/training_data.npy')
    test_data = np.load('cifar10/testing_data.npy')

    train_label = np.load('cifar10/training_label.npy')
    test_label = np.load('cifar10/testing_label.npy')

    for _ in range(n_trials):
        train_data0, train_label0, test_data0, test_label0 = generate_data2(
            lm1, N_tr, N_ts, train_data, train_label, test_data, test_label)

        model = neural_netowrk()
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
        model.fit(x=train_data0, y=train_label0,
                  batch_size=batch_size, validation_split=0.1, epochs=epochs)
        # Called for its printed test loss/accuracy; the result is not needed.
        model.evaluate(test_data0, test_label0)

        ypred = np.argmax(model.predict(test_data0, verbose=1), axis=1)
        y_true = np.ravel(test_label0)
        group2 = y_true == 1
        # Fraction of label-1 samples predicted as 1. Computed directly rather
        # than as cm[1,1] / (cm[1,0] + cm[1,1]): the model has 10 softmax
        # outputs, so it can predict classes 2-9, and that denominator silently
        # dropped such samples, inflating the accuracy.
        acc = np.mean(ypred[group2] == 1) if np.any(group2) else float('nan')
        print("group 2 accuracy = %f" % acc)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
